package cn.atcoder.air.client;

import cn.atcoder.air.exception.ClientTimeoutException;
import cn.atcoder.air.exception.MessageException;
import cn.atcoder.air.exception.NoProviderException;
import cn.atcoder.air.msg.Invocation;
import cn.atcoder.air.msg.MessageHeader;
import cn.atcoder.air.transport.ClientTransportFactory;
import cn.atcoder.air.util.DateUtils;
import io.netty.channel.Channel;

import java.util.Date;
import java.util.concurrent.CancellationException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

/**
 * Client-side {@link Future} for a single request/response exchange.
 * <p>
 * The future is completed at most once, through {@link #setSuccess}, {@link #setFailure} or
 * {@link #cancel}. Threads blocked in {@link #get} wait on this object's monitor and are woken
 * with {@code notifyAll()}. While no result is present, {@link #isDone()} additionally checks
 * {@code ClientTransportFactory.refer().isOpen()} and throws {@link NoProviderException} if the
 * transport is closed, which aborts callers waiting in {@link #get}.
 *
 * @author yangjunda1
 * @date 5/9/19 5:40 PM
 */
public class MessageFuture<V> implements Future<V> {

    /**
     * Completion state: {@code null} while pending, a {@link CauseHolder} on failure or
     * cancellation, otherwise the response value. Volatile because it is read outside the
     * monitor (e.g. {@link #isDone()}, {@link #getNow()}) while another thread writes it.
     */
    private volatile Object result;

    private final MessageHeader header;

    /** Invocation this future belongs to; inside this class it is only used for error messages. */
    private volatile Invocation invocationBody;

    /** Number of threads currently blocked in {@link #await}; mutated only while holding {@code this}. */
    private short waiters;

    /** Timeout configured by the user, in milliseconds. */
    private final int timeout;

    /** Wall-clock time (epoch millis) at which this future was created. */
    private final long genTime = System.currentTimeMillis();

    /** Time at which the request was sent (same clock as {@link #genTime}), or 0 if not recorded yet. */
    private volatile long sentTime;

    /** Private sentinel object, so no response value can ever be mistaken for it. */
    private static final Object UNCANCELLABLE = new Object();

    private static final CauseHolder CANCELLATION_CAUSE_HOLDER = new CauseHolder(new CancellationException());

    /** Connection associated with this future; used in timeout diagnostics. */
    private final Channel channel;

    /**
     * @param header  header of the request this future tracks; attached to raised exceptions
     * @param timeout default timeout in milliseconds used by {@link #get()}
     * @param channel connection associated with the request
     */
    public MessageFuture(MessageHeader header, int timeout, Channel channel) {
        this.header = header;
        this.timeout = timeout;
        this.channel = channel;
    }

    /**
     * Cancels the request if it has not completed yet. {@code mayInterruptIfRunning} is ignored.
     *
     * @return {@code true} if this call cancelled the future; {@code false} if it was already
     *         completed or marked uncancellable
     */
    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
        // Only the completion state matters here (not transport liveness), so a pending
        // future can still be cancelled after the connection has gone away.
        if (result != null) {
            return false;
        }

        synchronized (this) {
            // Re-check under the lock: allow only one completion.
            if (result != null) {
                return false;
            }

            result = CANCELLATION_CAUSE_HOLDER;
            if (hasWaiters()) {
                notifyAll();
            }
        }
        return true;
    }

    /**
     * @return {@code true} if {@link #cancel} completed this future
     */
    @Override
    public boolean isCancelled() {
        return result == CANCELLATION_CAUSE_HOLDER;
    }

    /**
     * Reports whether the future has completed (successfully, exceptionally or by cancellation).
     *
     * @return {@code true} if a result is present
     * @throws NoProviderException if no result is present yet and the client transport is closed
     */
    @Override
    public boolean isDone() {
        // A result that has already arrived stays readable even after the transport closes.
        if (isCompleted(result)) {
            return true;
        }
        if (!ClientTransportFactory.refer().isOpen()) {
            throw noProviderException();
        }
        return false;
    }

    /** Pure state check: completed means a non-null result other than the uncancellable marker. */
    private static boolean isCompleted(Object result) {
        return result != null && result != UNCANCELLABLE;
    }

    /**
     * Waits for the result using the timeout supplied at construction (milliseconds).
     *
     * @see #get(long, TimeUnit)
     */
    @Override
    public V get() throws InterruptedException {
        return get(timeout, TimeUnit.MILLISECONDS);
    }

    /**
     * Waits for the result. Time already spent on the client side (from creation until the
     * request was sent, or until now if the send time has not been recorded) is deducted from
     * {@code timeout}.
     *
     * @return the response value
     * @throws InterruptedException   if the waiting thread is interrupted
     * @throws ClientTimeoutException if no result arrives within the remaining time
     * @throws MessageException       if the future failed or was cancelled (see {@link #getNow()})
     * @throws NoProviderException    if the transport is closed while the result is pending
     */
    @Override
    public V get(long timeout, TimeUnit unit) throws InterruptedException {
        long timeoutMillis = unit.toMillis(timeout);
        long sent = sentTime;
        // Before sentTime is recorded it is 0; subtracting it would inflate the wait by the
        // whole epoch timestamp, so measure the client-side elapsed time up to "now" instead.
        long clientElapsed = Math.max(0L, (sent > 0 ? sent : System.currentTimeMillis()) - genTime);
        long remainTime = timeoutMillis - clientElapsed;
        if (remainTime <= 0) {
            // No time left: only return a result that is already present.
            if (isDone()) {
                return getNow();
            }
        } else if (await(remainTime, TimeUnit.MILLISECONDS)) {
            return getNow();
        }
        throw clientTimeoutException();
    }

    /**
     * Returns the current result without waiting.
     *
     * @return the response value, or {@code null} if the future has not completed
     * @throws MessageException if the future failed or was cancelled; a failure cause that is
     *                          already a {@link MessageException} is rethrown with this future's
     *                          header attached, any other cause is wrapped
     */
    @SuppressWarnings("unchecked")
    public V getNow() {
        Object current = this.result;
        if (current instanceof CauseHolder) {
            Throwable cause = ((CauseHolder) current).cause;
            if (cause instanceof MessageException) {
                MessageException rpcException = (MessageException) cause;
                rpcException.setHeader(header);
                throw rpcException;
            }
            throw new MessageException(this.header, cause);
        }
        if (!isCompleted(current)) {
            return null;
        }
        return (V) current;
    }

    /**
     * Waits interruptibly for completion.
     *
     * @return {@code true} if the future completed within the given time
     * @throws InterruptedException if the thread is interrupted while waiting
     * @throws NoProviderException  if the transport is closed while the result is pending
     */
    public boolean await(long timeout, TimeUnit unit)
            throws InterruptedException {
        return await0(unit.toNanos(timeout), true);
    }

    private boolean await0(long timeoutNanos, boolean interruptable) throws InterruptedException {
        if (isDone()) {
            return true;
        }

        if (timeoutNanos <= 0) {
            return isDone();
        }

        if (interruptable && Thread.interrupted()) {
            throw new InterruptedException(toString());
        }

        long startTime = System.nanoTime();
        long waitTime = timeoutNanos;
        boolean interrupted = false;

        try {
            synchronized (this) {
                if (isDone()) {
                    return true;
                }

                if (waitTime <= 0) {
                    return isDone();
                }

                incWaiters();
                try {
                    for (; ; ) {
                        try {
                            // Split nanoseconds into the (millis, nanos) pair Object.wait expects.
                            wait(waitTime / 1000000, (int) (waitTime % 1000000));
                        } catch (InterruptedException e) {
                            if (interruptable) {
                                throw e;
                            } else {
                                interrupted = true;
                            }
                        }

                        if (isDone()) {
                            return true;
                        } else {
                            // Woken early or spuriously: wait only for what is left.
                            waitTime = timeoutNanos - (System.nanoTime() - startTime);
                            if (waitTime <= 0) {
                                return isDone();
                            }
                        }
                    }
                } finally {
                    decWaiters();
                }
            }
        } finally {
            if (interrupted) {
                // Restore the interrupt flag swallowed in non-interruptible mode.
                Thread.currentThread().interrupt();
            }
        }
    }


    /** Wraps a failure (or cancellation) cause so it can be stored in {@link #result}. */
    private static final class CauseHolder {
        final Throwable cause;

        private CauseHolder(Throwable cause) {
            this.cause = cause;
        }
    }

    private boolean hasWaiters() {
        return waiters > 0;
    }

    private void incWaiters() {
        if (waiters == Short.MAX_VALUE) {
            throw new IllegalStateException("too many waiters: " + this);
        }
        waiters++;
    }

    private void decWaiters() {
        waiters--;
    }

    /**
     * Completes the future with a response value and wakes waiting threads.
     * Note: a {@code null} value is stored as-is and therefore does not mark the future as done.
     *
     * @return this future
     * @throws IllegalStateException if the future was already completed
     */
    public MessageFuture<V> setSuccess(V result) {
        if (setSuccess0(result)) {
            return this;
        }
        throw new IllegalStateException("complete already: " + this);
    }

    private boolean setSuccess0(V result) {
        // Pure state check: completing must not depend on transport liveness.
        if (isCompleted(this.result)) {
            return false;
        }

        synchronized (this) {
            // Allow only once.
            if (isCompleted(this.result)) {
                return false;
            }
            this.result = result;
            if (hasWaiters()) {
                notifyAll();
            }
        }
        return true;
    }


    /**
     * Completes the future exceptionally and wakes waiting threads.
     *
     * @return this future
     * @throws IllegalStateException if the future was already completed
     */
    public MessageFuture<V> setFailure(Throwable cause) {
        if (setFailure0(cause)) {
            return this;
        }
        throw new IllegalStateException("complete already: " + this, cause);
    }

    private boolean setFailure0(Throwable cause) {
        // Pure state check: failing a pending future must work even when the transport is down.
        if (isCompleted(this.result)) {
            return false;
        }

        synchronized (this) {
            // Allow only once.
            if (isCompleted(this.result)) {
                return false;
            }

            result = new CauseHolder(cause);
            if (hasWaiters()) {
                notifyAll();
            }
        }
        return true;
    }

    /**
     * Builds a {@link ClientTimeoutException} describing where the time was spent: before the
     * send ("Consumer send request timeout") or while waiting for the response
     * ("Waiting return response timeout").
     *
     * @return the exception (not thrown)
     */
    public ClientTimeoutException clientTimeoutException() {
        Date now = new Date();
        // Snapshot the volatile field so every part of the message uses the same value.
        long sent = sentTime;
        String errorMsg = (sent > 0 ? "Waiting return response timeout"
                : "Consumer send request timeout")
                + ". Start time: " + DateUtils.dateToMillisStr(new Date(genTime))
                + ", End time: " + DateUtils.dateToMillisStr(now)
                + (sent > 0
                        ? ", Client elapsed: " + (sent - genTime)
                                + "ms, Server elapsed: " + (now.getTime() - sent)
                        : ", Client elapsed: " + (now.getTime() - genTime))
                + "ms, Timeout: " + timeout
                + "ms, MsgHeader: " + this.header
                + ", Channel: " + describeChannel();
        return new ClientTimeoutException(errorMsg);
    }

    /** Renders "local->remote" for the channel, tolerating a missing channel. */
    private String describeChannel() {
        return channel == null ? "null" : channel.localAddress() + "->" + channel.remoteAddress();
    }

    /**
     * Builds a {@link NoProviderException} naming the invoked method and the remote address
     * reported by {@code ClientTransportFactory.refer()}.
     *
     * @return the exception (not thrown)
     */
    public NoProviderException noProviderException() {
        Invocation invocation = invocationBody;
        // invocationBody is set via a setter and may still be null; don't let an NPE mask this error.
        String method = invocation == null
                ? "unknown"
                : invocation.getClazzName() + "." + invocation.getMethodName();
        String errorMsg = "Execute method : [" + method + "] No alive provider of pinpoint address : ["
                + ClientTransportFactory.refer().getRemoteAddress() + "]!";
        return new NoProviderException(errorMsg);
    }

    public Object getResult() {
        return result;
    }

    /**
     * Overwrites the raw result. Unlike {@link #setSuccess}, this does not enforce single
     * completion and does not wake waiting threads.
     */
    public void setResult(Object result) {
        this.result = result;
    }

    public MessageHeader getHeader() {
        return header;
    }

    public short getWaiters() {
        return waiters;
    }

    public void setWaiters(short waiters) {
        this.waiters = waiters;
    }

    public int getTimeout() {
        return timeout;
    }

    public long getGenTime() {
        return genTime;
    }

    public long getSentTime() {
        return sentTime;
    }

    public void setSentTime(long sentTime) {
        this.sentTime = sentTime;
    }

    public Channel getChannel() {
        return channel;
    }

    public Invocation getInvocationBody() {
        return invocationBody;
    }

    public void setInvocationBody(Invocation invocationBody) {
        this.invocationBody = invocationBody;
    }
}
